import math
import numpy as np
import pickle
from scipy import interpolate
import matplotlib as mpl
import matplotlib.pyplot as plt
from scipy.optimize import minimize
from scipy.optimize import brute
from matplotlib.ticker import (MultipleLocator, AutoMinorLocator)
import emcee

#spectral_dict=pickle.load(open("spectral_dict.pickle", "rb"))

def fit_HDO(fitline, spectrum_file, template_file, lines, linefreqs, lineamp,
            guesses, labels, bounds, rms_range, measure_range,
            template_file_coms='', useMCMC=True, label_v_offset=3.0,
            label_y_offset=0.015, plot_title=''):
    """Fit a blend of template line profiles to a spectrum and measure a flux.

    The model is a sum of one template profile per entry of ``linefreqs``.
    Every profile is shifted by a single shared velocity (parameter 0),
    resampled onto the spectrum's frequency grid, normalised to unit sum and
    scaled by its own amplitude (parameters 1..N).  The parameters are found
    with ``scipy.optimize.minimize`` (``trust-constr``) and, optionally,
    refined with an ``emcee`` ensemble sampler.

    Parameters
    ----------
    fitline : str
        Name of the fit.  Used as the prefix of every output file and to pick
        the legend text / panel label ('HDO', '225', '241', 'H218O'
        substrings are recognised).
    spectrum_file : str
        Text file read with ``np.loadtxt``; column 0 is frequency (multiplied
        by 1e6 here -- presumably MHz converted to Hz, confirm with the data),
        column 1 the flux density and column 2 its per-channel uncertainty.
    template_file : str
        Text file with the profile used for lines whose label contains
        'HDO' or 'H$_2$$^{18}$O'.  Column 0 is a frequency offset from line
        centre (also multiplied by 1e6), column 1 the profile.
    lines : sequence of str
        Label of each line; also drawn on the plots.
    linefreqs : sequence of float
        Rest frequency of each line, in the units of the scaled frequency
        axis.
    lineamp : sequence of float
        y position at which each line's tick mark is drawn.
    guesses : sequence of float
        Initial parameters ``[velocity, amp_1, ..., amp_N]``.
    labels : sequence of str
        Parameter names used on the walker and corner plots.
    bounds : sequence of (float, float)
        ``(low, high)`` for each parameter.  NOTE: the MCMC prior uses strict
        inequalities, so a parameter with ``low == high`` is always rejected;
        use ``useMCMC=False`` with such bounds.
    rms_range : sequence of two floats
        The residual RMS is measured on channels *outside* this range.
    measure_range : sequence of two floats
        The line flux is summed on channels strictly inside this range.
    template_file_coms : str, optional
        Separate template for all other lines.  When empty, ``template_file``
        is used for every line.
    useMCMC : bool, optional
        Run the emcee sampler and use posterior medians as best-fit values.
    label_v_offset, label_y_offset : float, optional
        Offsets (velocity and y) of the text labels relative to the ticks.
    plot_title : str, optional
        Title of the combined best-fit figure.

    Returns
    -------
    (float, float)
        Flux summed over ``measure_range`` after subtracting the model of the
        non-target lines, and its uncertainty (residual RMS in that range
        times the square root of the number of channels).

    Side effects
    ------------
    Updates matplotlib rcParams, prints fit results, and writes
    ``<fitline>_best_fit_{combined,COM_sub,HDO_sub,residual}.{png,pdf}`` plus,
    with ``useMCMC``, ``<fitline>_walkers.png`` and ``<fitline>_corner.png``.
    """
    panel_label = ''
    if 'HDO' in fitline:
        legend_text = 'HDO'
        if '241' in fitline:
            panel_label = 'c)'
        elif '225' in fitline:
            panel_label = 'a)'
    elif 'H218O' in fitline:
        legend_text = 'H$_2$$^{18}$O'
        panel_label = 'e)'
    else:
        legend_text = fitline

    plt.rcParams.update({'font.size': 20})
    mpl.rcParams['axes.formatter.useoffset'] = False

    # ---- load the observed spectrum -------------------------------------
    data = np.loadtxt(spectrum_file)
    freq_axis = data[:, 0] * 1.0e6
    spectrum = data[:, 1].copy()
    e_spectrum = data[:, 2].copy()

    # ---- load the template(s) -------------------------------------------
    template = np.loadtxt(template_file)
    spectral_template = dict(freq_axis=template[:, 0] * 1.0e6,
                             spectrum=template[:, 1])
    if template_file_coms != '':
        template_coms = np.loadtxt(template_file_coms)
        # Resample the COM template onto the grid of the first template;
        # anything outside its coverage is treated as zero.
        regrid = interpolate.interp1d(template_coms[:, 0] * 1.0e6,
                                      template_coms[:, 1],
                                      fill_value=(0, 0), bounds_error=False)
        spectral_template_coms = dict(
            freq_axis=spectral_template['freq_axis'].copy(),
            spectrum=regrid(spectral_template['freq_axis']))
    else:
        spectral_template_coms = spectral_template

    def is_target_line(name):
        # Lines labelled as HDO / H2(18)O use the primary template and make
        # up the "HDO" model; everything else is the "COM" model.
        return ('HDO' in name) or ('H$_2$$^{18}$O' in name)

    def rms(x, axis=None):
        return np.sqrt(np.mean(x**2, axis=axis))

    def measure_rms(rms_range, residual):
        # RMS of the residual on the channels outside rms_range.
        outside = ((freq_axis < rms_range[0]) | (freq_axis > rms_range[1])).nonzero()[0]
        return rms(residual[outside])

    def measure_flux_residuals(measure_range, fits):
        # Flux = sum of (spectrum - COM model) inside measure_range; its
        # error is the full-model residual RMS there times sqrt(N channels).
        inside = ((freq_axis > measure_range[0]) & (freq_axis < measure_range[1])).nonzero()[0]
        flux = np.sum(plot_com(fits)[1][inside])
        rms_resid = rms(plot_spectra(fits)[1][inside])
        e_flux = rms_resid * (len(inside))**0.5
        return flux, e_flux

    def measure_flux_random(measure_range, fits):
        # Currently unused (see the disabled block near the end): refit with
        # the velocity pinned to randomly perturbed values and collect the
        # resulting fluxes.
        nsamples = 5
        hdo_fluxes = np.zeros(nsamples)
        hdo_e_fluxes = np.zeros(nsamples)
        hdo_model_fluxes = np.zeros(nsamples)
        velocities = np.abs(fits[0]) + 0.05*np.abs(fits[0]) * np.random.randn(nsamples)
        for i, velocity in enumerate(velocities):
            boundsnew = ((velocity, velocity),) + tuple(bounds[1:])
            guessesnew = [velocity] + list(guesses[1:])
            result_new = minimize(nll, guessesnew, args=(spectrum, e_spectrum),
                                  method='trust-constr', bounds=boundsnew)
            newfits = result_new['x']
            hdo_model_fluxes[i] = newfits[1]
            hdo_fluxes[i], hdo_e_fluxes[i] = measure_flux_residuals(measure_range, newfits)
        return hdo_fluxes, hdo_e_fluxes, hdo_model_fluxes

    def apply_templates(x):
        # Shift each line's template to its rest frequency and by velocity x
        # (x / 3.0e5 is the fractional shift), resample it onto the spectrum
        # grid and normalise it to unit sum.
        freq_axis_lines = []
        spectrum_lines = []
        for i, restfreq in enumerate(linefreqs):
            if is_target_line(lines[i]):
                source = spectral_template
            else:
                source = spectral_template_coms
            shifted_axis = source['freq_axis'] + restfreq - x/3.0e5*restfreq
            resample = interpolate.interp1d(shifted_axis, source['spectrum'],
                                            fill_value=(0, 0), bounds_error=False)
            profile = resample(freq_axis)
            freq_axis_lines.append(shifted_axis)
            spectrum_lines.append(profile / np.sum(profile))
        return freq_axis_lines, spectrum_lines

    def build_model(x, include):
        # Sum the amplitude-scaled profiles of the lines selected by
        # include(label); return [model, spectrum - model].
        _, spectrum_lines = apply_templates(x[0])
        comb_spectrum = np.zeros(len(freq_axis))
        for i in range(len(linefreqs)):
            if include(lines[i]):
                comb_spectrum += spectrum_lines[i] * x[i+1]
        return [comb_spectrum, spectrum - comb_spectrum]

    def plot_spectra(x):
        return build_model(x, lambda name: True)

    def plot_hdo(x):
        return build_model(x, is_target_line)

    def plot_com(x):
        return build_model(x, lambda name: not is_target_line(name))

    def get_line_flux(x):
        _, spectrum_lines = apply_templates(x[0])
        return [np.sum(spectrum_lines[i] * x[i+1]) for i in range(len(linefreqs))]

    def lnprior(p):
        # Flat prior inside the (open) bounds, -inf outside.
        for i in range(len(p)):
            if not bounds[i][0] < p[i] < bounds[i][1]:
                return -np.inf
        return 0.0

    def lnlike(p, spectrum, e_spectrum):
        _, spectrum_lines = apply_templates(p[0])
        model = np.zeros(len(freq_axis))
        for i in range(1, len(p)):
            model += spectrum_lines[i-1] * p[i]
        # Gaussian log-likelihood up to an additive constant.  The residuals
        # are weighted by 1/sigma^2; the previous code divided by sigma^4.
        sigma2 = e_spectrum**2
        return -0.5 * np.sum((spectrum - model)**2 / sigma2)

    def lnprob(p, spectrum, e_spectrum):
        lp = lnprior(p)
        if not np.isfinite(lp):
            return -np.inf
        return lp + lnlike(p, spectrum, e_spectrum)

    def nll(*args):
        return -lnlike(*args)

    result = minimize(nll, guesses, args=(spectrum, e_spectrum),
                      method='trust-constr', bounds=bounds)
    print(result['x'])

    if useMCMC:
        # Start 100 walkers scattered by 10% around the optimiser solution.
        pos = np.abs(result.x) + 0.1*np.abs(result.x) * np.random.randn(100, len(linefreqs)+1)
        nwalkers, ndim = pos.shape

        sampler = emcee.EnsembleSampler(nwalkers, ndim, lnprob, args=(spectrum, e_spectrum))
        sampler.run_mcmc(pos, 1500, progress=True)

        fig, axes = plt.subplots(ndim, figsize=(10, len(linefreqs)+1), sharex=True)
        samples = sampler.get_chain()
        for i in range(ndim):
            walker_ax = axes[i]
            walker_ax.plot(samples[:, :, i], "k", alpha=0.3)
            walker_ax.set_xlim(0, len(samples))
            walker_ax.set_ylabel(labels[i])
            walker_ax.yaxis.set_label_coords(-0.1, 0.5)
        axes[-1].set_xlabel("step number")
        plt.savefig(fitline+'_walkers.png')

        flat_samples = sampler.get_chain(discard=500, thin=15, flat=True)
        print(flat_samples.shape)

        import corner
        fig = corner.corner(flat_samples, labels=labels)
        plt.savefig(fitline+'_corner.png')

        # Best fit = posterior median; uncertainties = 16th/84th percentile
        # offsets from it.
        bestfits = np.zeros(ndim)
        unc = np.zeros((ndim, 2))
        for i in range(ndim):
            bestfits[i] = np.median(flat_samples[:, i])
            unc[i] = np.percentile(flat_samples[:, i], [16, 84]) - bestfits[i]

        print(bestfits)
        print(unc)
    else:
        bestfits = result['x']
        print(bestfits)

    line_fluxes = get_line_flux(bestfits)
    hdo_model, hdo_residual = plot_hdo(bestfits)
    com_model, com_residual = plot_com(bestfits)
    residual_spectrum = plot_spectra(bestfits)[1]

    freq_ghz = freq_axis/1e9

    def new_axes():
        _, new_ax = plt.subplots(1, 1, figsize=(15, 10))
        return new_ax

    def step(target_ax, values, **kwargs):
        target_ax.plot(freq_ghz, values, drawstyle='steps', linewidth='2', **kwargs)

    def decorate(target_ax, **legend_kwargs):
        # Shared legend / tick / label / limit / line-annotation styling.
        target_ax.legend(fontsize=20, **legend_kwargs)
        target_ax.xaxis.set_major_locator(MultipleLocator(0.005))
        target_ax.xaxis.set_minor_locator(AutoMinorLocator())
        target_ax.tick_params(which='both', width=2)
        target_ax.tick_params(which='major', length=8)
        target_ax.tick_params(which='minor', length=4)
        target_ax.set_ylabel('Flux Density (Jy)')
        target_ax.set_xlabel('Frequency (GHz)')
        target_ax.set_xlim(np.min(freq_ghz), np.max(freq_ghz))
        target_ax.set_ylim(np.min(spectrum)-0.05*np.min(spectrum),
                           np.max(spectrum)+0.2*np.max(spectrum))
        for i in range(len(lines)):
            target_ax.text(linefreqs[i]/1e9-(bestfits[0]+label_v_offset)/3.0e5*linefreqs[i]/1e9,
                           lineamp[i]+label_y_offset, lines[i])
            target_ax.text(linefreqs[i]/1e9-bestfits[0]/3.0e5*linefreqs[i]/1e9,
                           lineamp[i], '|')

    def save_figure(stem):
        # Saves the current figure, i.e. the one created most recently.
        plt.savefig(fitline+stem+'.png')
        plt.savefig(fitline+stem+'.pdf')

    # ---- figure 1: spectrum with both model components and their sum -----
    ax = new_axes()
    step(ax, spectrum, color='black', label='Full Spectrum')
    step(ax, hdo_model, label=legend_text+' Model')
    step(ax, com_model, label='COM Model')
    step(ax, com_model+hdo_model, label=legend_text+' + COM Model')
    decorate(ax, loc='upper left')
    ax.set_title(plot_title)
    ax.text(np.max(freq_ghz)-0.000005*np.max(freq_ghz),
            ax.get_ylim()[1]-ax.get_ylim()[1]*0.05, panel_label)
    save_figure('_best_fit_combined')

    # ---- figure 2: COM model and the COM-subtracted spectrum -------------
    ax = new_axes()
    step(ax, spectrum, color='black', label='Full Spectrum')
    step(ax, com_model, label='COM Model')
    step(ax, com_residual, label='COM Residual')
    decorate(ax)
    save_figure('_best_fit_COM_sub')

    # ---- figure 3: target-line model and the spectrum minus it -----------
    ax = new_axes()
    step(ax, spectrum, color='black', label='Full Spectrum')
    step(ax, hdo_model, label=legend_text+' Model')
    step(ax, hdo_residual, label=legend_text+' Residual')
    decorate(ax)
    save_figure('_best_fit_HDO_sub')

    # ---- figure 4: full-model residual ------------------------------------
    ax = new_axes()
    step(ax, residual_spectrum, label='Residual')
    step(ax, com_model+hdo_model, label=legend_text+' + COM Model')
    step(ax, spectrum, color='black', label='Full Spectrum')
    decorate(ax)
    save_figure('_best_fit_residual')

    print(line_fluxes)
    rmsflux = measure_rms(rms_range, residual_spectrum)
    hdo_flux, hdo_e_flux = measure_flux_residuals(measure_range, bestfits)

    #MCMC FLUX ERROR COMMENTED BECAUSE IT CHANGES
    #hdo_fluxes_random,hdo_e_fluxes_random,hdo_model_fluxes_random=measure_flux_random(measure_range, bestfits)
    #print(hdo_fluxes_random,hdo_e_fluxes_random,hdo_model_fluxes_random)
    #hdo_e_flux_mcmc=np.max(np.abs(np.percentile(hdo_fluxes_random,[16,84])-np.median(hdo_fluxes_random)))
    #print('measured flux and uncert',np.median(hdo_fluxes_random),np.percentile(hdo_fluxes_random,[16,84])-np.median(hdo_fluxes_random))
    #print('model flux and uncert',np.median(hdo_model_fluxes_random),np.percentile(hdo_model_fluxes_random,[16,84])-np.median(hdo_model_fluxes_random))
    #if hdo_e_flux_mcmc > hdo_e_flux:
    #   hdo_e_flux=hdo_e_flux_mcmc
    #   print('USING MCMC FLUX ERROR')
    print('Spectrum RMS '+fitline+' {:0.3f}'.format(rmsflux))
    print('Line Flux '+fitline+' (from spectrum): {:0.2f} +/- {:0.2f}'.format(hdo_flux, hdo_e_flux))

    for i in range(len(bestfits)-1):
        print('Line Flux {} (from model): {:0.2f} '.format(lines[i], bestfits[i+1]))
    return hdo_flux, hdo_e_flux

# Disabled example invocations, kept as an inert module-level string (it is
# never executed).  They show how fit_HDO is called for the 'HDO_225',
# 'HDO_241' and 'H218O_203' fits; each call uses useMCMC=False together with
# a fixed velocity bound (4.3, 4.3) and input paths relative to the working
# directory.
'''
fitline='HDO_225'

restfreq_hdo=225.89672e9
restfreq_ch3ocho=225.90074e9
restfreq_13ch3cho=225.89294840e9
linefreqs=[restfreq_hdo,restfreq_ch3ocho,restfreq_13ch3cho]
lines=['           HDO','    CH$_3$OCHO','$^{13}$CH$_3$CHO']
lineamp=[0.14,0.1,0.1]

v_start=4.3
a_1_s=3.4
a_2_s=2.4
a_3_s=0.5
labels=['v','a1','a2','a3']
guesses=[v_start,a_1_s,a_2_s,a_3_s]
bounds=((4.3,4.3),(0.0,10.0),(0.0,10.0),(0.0,10.0))
rms_range=[225.89e9,225.9e9]
measure_range=[225.89e9,225.8975e9]
fit_HDO(fitline,'HDO/HDO_225_text_spectra.txt','HDO/H218O_text_spectral_template.txt',
            lines,linefreqs,lineamp,guesses,labels,bounds,rms_range,measure_range, template_file_coms='',useMCMC=False)


fitline='HDO_241'
restfreq_hdo=241.56155e9
restfreq_ch3cho=241.56322770e9
restfreq_ch32co=241.55988690e9
linefreqs=[restfreq_hdo,restfreq_ch3cho,restfreq_ch32co]
lines=['             HDO','                CH$_3$CHO','(CH$_3$)$_2$CO']
lineamp=[0.19,0.19,0.13]
v_start=4.3
a_1_s=3.0
a_2_s=1.5
a_3_s=0.5
labels=['v','a1','a2','a3']
guesses=[v_start,a_1_s,a_2_s,a_3_s]
bounds=((4.3,4.3),(0.0,10.0),(0.0,10.0),(0.0,10.0))
rms_range=[241.555e9,241.564e9]
measure_range=[241.555e9,241.562e9]
fit_HDO(fitline,'HDO/HDO_241_text_spectra.txt','HDO/H218O_text_spectral_template.txt',
            lines,linefreqs,lineamp,guesses,labels,bounds,rms_range,measure_range, template_file_coms='',useMCMC=False)

fitline='H218O_203'

restfreq_h218o=203.40752000e9
restfreq_ch3och3_1=203.41011260e9
restfreq_ch3och3_2=203.41143090e9
restfreq_ch3och3_3=203.40277790e9

linefreqs=[restfreq_h218o,restfreq_ch3och3_1,restfreq_ch3och3_2,restfreq_ch3och3_3]
lines=['H$_2$$^{18}$O','CH$_3$OCH$_3$','','']
lineamp=[0.06,0.12,0.12,0.12]

v_start=4.3
a_start=[0.5,1.0,1.0,1.0]
a_1_s=0.1
a_2_s=0.1
a_3_s=0.1
a_4_s=0.1
labels=['v','a1','a2','a3','a4']
guesses=[v_start,a_1_s,a_2_s,a_3_s,a_4_s]
bounds=((4.3,4.3),(0.0,10.0),(0.0,10.0),(0.0,10.0),(0.0,10.0))
rms_range=[203.401e9,203.408e9]
measure_range=[203.401e9,203.408e9]
fit_HDO(fitline,'H218O/H218O_203_text_spectra.txt','HDO/H218O_text_spectral_template.txt',
            lines,linefreqs,lineamp,guesses,labels,bounds,rms_range,measure_range, template_file_coms='',useMCMC=False)

'''
